{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/hpc2hdd/home/xzou428/Yuhao/HiGPT-tune-lightning\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/hpc2hdd/home/xzou428/miniconda3/envs/yuh/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:13<00:00,  3.34s/it]\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "  0%|          | 0/125 [00:00<?, ?it/s]We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n",
      "100%|██████████| 125/125 [31:38<00:00, 15.18s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'perplexities': [8.010628700256348, 11.617366790771484, 16.436626434326172, 12.452149391174316, 5.280788421630859, 15.262550354003906, 1.1191734075546265, 8.966885566711426, 5.377917289733887, 6.820733070373535, 10.366116523742676, 5.997546672821045, 9.961697578430176, 5.420133590698242, 9.82841682434082, 10.09469985961914, 6.015024662017822, 7.051065444946289, 12.503717422485352, 5.131502628326416, 10.94749641418457, 11.934362411499023, 7.664680004119873, 15.18045711517334, 12.17275333404541, 11.387449264526367, 10.594610214233398, 9.220525741577148, 9.481122016906738, 5.799471378326416, 4.850359916687012, 8.029363632202148, 11.793526649475098, 4.3261590003967285, 16.842838287353516, 10.726288795471191, 8.31452751159668, 12.031231880187988, 11.358588218688965, 11.064593315124512, 1.0780762434005737, 11.942229270935059, 9.76633358001709, 13.83141803741455, 13.307634353637695, 8.395865440368652, 9.22557544708252, 7.7516679763793945, 11.457433700561523, 9.409688949584961, 6.23565149307251, 13.601834297180176, 11.01682186126709, 13.83610725402832, 19.9516658782959, 9.33217716217041, 15.78387451171875, 7.547257423400879, 11.465177536010742, 10.141702651977539, 7.7508015632629395, 4.400157928466797, 8.206475257873535, 3.9529998302459717, 9.017908096313477, 12.808181762695312, 8.382381439208984, 8.570204734802246, 5.058107376098633, 12.327993392944336, 10.220998764038086, 9.337959289550781, 10.435149192810059, 5.0961103439331055, 16.522897720336914, 6.60452127456665, 8.885839462280273, 4.4701104164123535, 12.076196670532227, 15.767400741577148, 13.618545532226562, 6.078358173370361, 18.256649017333984, 10.73677921295166, 7.4286322593688965, 9.502052307128906, 10.708578109741211, 21.418699264526367, 6.243850231170654, 8.858406066894531, 7.630718231201172, 4.250406265258789, 9.741896629333496, 6.861844062805176, 8.790027618408203, 5.581715106964111, 11.816938400268555, 21.221420288085938, 10.870670318603516, 13.798497200012207, 12.270033836364746, 12.0072660446167, 
7.87317419052124, 8.445389747619629, 12.166735649108887, 8.811786651611328, 8.933721542358398, 10.440140724182129, 8.769749641418457, 15.67431640625, 5.940704822540283, 14.466821670532227, 13.264716148376465, 8.836137771606445, 11.988678932189941, 3.4359548091888428, 10.201058387756348, 8.471302032470703, 6.930070400238037, 8.382186889648438, 22.832719802856445, 2.8186168670654297, 13.11449146270752, 4.754410743713379, 9.033476829528809, 10.557740211486816, 8.977855682373047, 8.036605834960938, 5.353160858154297, 17.625993728637695, 10.17125415802002, 7.173891544342041, 8.766822814941406, 13.550002098083496, 7.840097904205322, 3.1190133094787598, 11.160945892333984, 13.710724830627441, 1.1064238548278809, 7.5461812019348145, 12.3268404006958, 7.709891319274902, 8.22543716430664, 4.163880825042725, 2.9265434741973877, 18.814067840576172, 8.697185516357422, 13.269853591918945, 12.02336597442627, 4.548267841339111, 1.116442084312439, 5.5470123291015625, 19.74461555480957, 5.3078484535217285, 5.369590759277344, 8.341388702392578, 9.720060348510742, 12.364195823669434, 13.543380737304688, 14.620802879333496, 10.545209884643555, 1.199769377708435, 6.551154613494873, 4.261473178863525, 1.143824815750122, 6.527069568634033, 10.181029319763184, 15.910719871520996, 1.0668494701385498, 4.404050350189209, 8.460091590881348, 1.0691184997558594, 8.358770370483398, 10.14790153503418, 6.419498920440674, 11.074981689453125, 15.543022155761719, 9.20804500579834, 9.49456787109375, 14.398754119873047, 6.803823947906494, 9.733506202697754, 9.115334510803223, 15.631030082702637, 13.37239933013916, 13.936918258666992, 7.680768966674805, 9.341581344604492, 14.949332237243652, 6.040980815887451, 10.193159103393555, 12.26705265045166, 8.207416534423828, 11.152083396911621, 8.511012077331543, 19.04939079284668, 9.881658554077148, 7.469366550445557, 9.914152145385742, 6.637253761291504, 8.573420524597168, 16.561283111572266, 9.947267532348633, 12.261446952819824, 8.632237434387207, 
8.450776100158691, 13.382251739501953, 9.926886558532715, 9.811447143554688, 4.089788436889648, 14.922588348388672, 8.6519193649292, 10.837435722351074, 2.244518280029297, 17.888185501098633, 12.210362434387207, 13.230228424072266, 9.490396499633789, 12.75918960571289, 12.435782432556152, 17.8333683013916, 11.419983863830566, 9.84284496307373, 4.607920169830322, 19.599624633789062, 2.974491834640503, 8.909934997558594, 13.03402328491211, 11.066923141479492, 9.668963432312012, 10.902055740356445, 5.762572288513184, 13.158354759216309, 13.008326530456543, 10.108842849731445, 11.134784698486328, 7.267085552215576, 1.062124490737915, 9.775235176086426, 11.904115676879883, 8.012434005737305, 6.0598907470703125, 9.94332504272461, 4.950897693634033, 13.354955673217773, 11.400622367858887, 9.555484771728516, 1.1731740236282349, 1.0657352209091187, 10.993343353271484, 6.526697158813477, 10.931946754455566, 7.76903772354126, 5.683549880981445, 9.318018913269043, 17.337989807128906, 3.567012071609497, 5.882146835327148, 10.120692253112793, 8.446206092834473, 10.48293685913086, 8.511122703552246, 7.682855129241943, 12.188946723937988, 11.659037590026855, 3.315966844558716, 21.582555770874023, 8.081626892089844, 11.35413646697998, 13.258367538452148, 4.129848957061768, 7.471471309661865, 7.907629013061523, 1.2784677743911743, 12.50883674621582, 14.381888389587402, 12.157588005065918, 10.195441246032715, 9.613504409790039, 7.937291145324707, 3.625206708908081, 19.378389358520508, 10.435039520263672, 5.523228645324707, 8.23906421661377, 10.063398361206055, 9.412212371826172, 10.569647789001465, 5.994065284729004, 4.647335052490234, 9.233803749084473, 10.607669830322266, 8.229901313781738, 12.516242027282715, 8.875160217285156, 13.094060897827148, 6.881019115447998, 10.895951271057129, 13.319835662841797, 8.504293441772461, 9.657577514648438, 4.872453689575195, 10.11726188659668, 9.23832893371582, 16.165945053100586, 9.356475830078125, 7.49825382232666, 5.798728942871094, 
3.105275869369507, 10.747064590454102, 10.2562894821167, 8.001354217529297, 22.503982543945312, 5.24068546295166, 7.696950435638428, 7.2454094886779785, 11.604303359985352, 9.304656028747559, 7.435369491577148, 9.298662185668945, 1.2701811790466309, 14.491539001464844, 8.861517906188965, 6.445970058441162, 1.0804553031921387, 8.424399375915527, 9.61799144744873, 5.793432712554932, 10.875128746032715, 12.418034553527832, 10.253976821899414, 1.1157675981521606, 7.712555408477783, 9.719476699829102, 8.144410133361816, 3.856320381164551, 3.70074725151062, 6.96307897567749, 13.27652359008789, 7.9877777099609375, 7.967118740081787, 20.568023681640625, 18.40328598022461, 17.376243591308594, 8.904082298278809, 16.3880672454834, 9.163630485534668, 10.113031387329102, 9.766498565673828, 6.99643611907959, 9.916024208068848, 10.395048141479492, 10.563835144042969, 9.037338256835938, 6.473631381988525, 7.846444606781006, 10.07083797454834, 8.506329536437988, 6.640144348144531, 15.597257614135742, 10.021827697753906, 9.007144927978516, 6.970981597900391, 8.890522956848145, 3.6563777923583984, 17.73195457458496, 9.31417179107666, 13.906554222106934, 7.458512783050537, 9.08506965637207, 6.842899799346924, 9.4472074508667, 5.030135154724121, 5.537978649139404, 9.684389114379883, 7.031853675842285, 11.963092803955078, 16.739397048950195, 11.750349998474121, 8.044574737548828, 7.697704792022705, 7.654092311859131, 4.5856733322143555, 12.900567054748535, 14.061549186706543, 8.960254669189453, 9.094598770141602, 5.286368370056152, 8.498762130737305, 1.0755640268325806, 8.672670364379883, 6.603630542755127, 12.043109893798828, 9.753754615783691, 5.050571441650391, 5.153223991394043, 19.42385482788086, 5.445138931274414, 14.174448013305664, 9.776571273803711, 9.310016632080078, 12.82960033416748, 16.950855255126953, 12.06783676147461, 7.532215118408203, 7.842898368835449, 12.698278427124023, 8.546833992004395, 13.425172805786133, 6.534655570983887, 7.1029462814331055, 7.320761203765869, 
8.66125202178955, 5.198532581329346, 9.466780662536621, 4.278669834136963, 12.723271369934082, 9.128168106079102, 10.217063903808594, 10.511289596557617, 6.084582805633545, 5.3642401695251465, 3.562915325164795, 8.426335334777832, 10.374357223510742, 8.778623580932617, 8.153676986694336, 14.400634765625, 13.160099029541016, 10.261362075805664, 7.381282329559326, 12.495804786682129, 5.583824157714844, 13.270204544067383, 11.37856674194336, 12.844073295593262, 6.6208577156066895, 15.05972957611084, 9.112695693969727, 7.232944488525391, 6.5740275382995605, 11.650289535522461, 5.275041103363037, 5.649501800537109, 15.403559684753418, 8.31787109375, 20.40690803527832, 3.069023609161377, 4.900405406951904, 8.094378471374512, 1.1300125122070312, 9.17198657989502, 9.923909187316895, 19.103694915771484, 14.408270835876465, 8.907293319702148, 18.07341957092285, 4.672732830047607, 9.250576972961426, 6.204880237579346, 9.781572341918945, 6.556175708770752, 10.099235534667969, 14.56065845489502, 5.71718168258667, 4.607837677001953, 9.806540489196777, 4.647514820098877, 1.9626755714416504, 8.297572135925293, 8.350302696228027, 12.681488037109375, 8.448878288269043, 5.525542259216309, 12.167974472045898, 6.690460205078125, 8.629075050354004, 10.834916114807129, 10.170345306396484, 13.171273231506348, 10.43557357788086, 6.119800090789795, 9.315835952758789, 5.448638916015625, 18.597383499145508, 3.714162826538086, 5.895522594451904, 4.322361469268799, 12.075030326843262, 7.262693881988525, 10.784793853759766, 8.453895568847656, 8.712320327758789, 10.948339462280273, 26.107072830200195, 10.307668685913086, 6.921197891235352, 5.980469226837158, 11.404870986938477, 7.62398624420166], 'mean_perplexity': 9.497796083211899}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%cd /hpc2hdd/home/xzou428/Yuhao/HiGPT-tune-lightning/\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "from evaluate import load\n",
    "\n",
    "# Reference LM used to score the generated text, and the directory of per-sample predictions.\n",
    "MODEL_ID = '/hpc2hdd/home/xzou428/Yuhao/llama3-8b-instruct'\n",
    "PREDICT_DIR = Path('checkpoints/vanilla-llama3-nlp_rw-epoch30-8192-full_finetune/lightning_logs/version_1/predict')\n",
    "\n",
    "# Path.glob yields files in arbitrary filesystem order; sort so the per-file\n",
    "# `perplexities` list is reproducible and can be matched back to file names.\n",
    "predict_files = sorted(PREDICT_DIR.glob('*.txt'))\n",
    "# An empty/wrong directory would otherwise surface as an obscure error inside the metric.\n",
    "assert predict_files, f'no .txt predictions found in {PREDICT_DIR.resolve()}'\n",
    "\n",
    "names, predictions, empty = [], [], []\n",
    "for path in predict_files:\n",
    "    text = path.read_text(encoding='utf-8').strip()\n",
    "    # The perplexity metric requires every input to be at least one token long, so a\n",
    "    # single empty file would abort the whole (~30 min) run; record and skip it instead.\n",
    "    if text:\n",
    "        names.append(path.name)\n",
    "        predictions.append(text)\n",
    "    else:\n",
    "        empty.append(path.name)\n",
    "if empty:\n",
    "    print(f'skipping {len(empty)} empty prediction file(s): {empty[:5]}')\n",
    "assert predictions, f'all prediction files in {PREDICT_DIR} are empty'\n",
    "\n",
    "perplexity = load('perplexity', module_type='metric')\n",
    "results = perplexity.compute(predictions=predictions, model_id=MODEL_ID, device='cuda', batch_size=4)\n",
    "# Keep per-sample scores addressable by file name instead of dumping the raw list.\n",
    "per_file_perplexity = dict(zip(names, results['perplexities']))\n",
    "print(f\"{len(predictions)} predictions, mean perplexity = {results['mean_perplexity']:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/hpc2hdd/home/xzou428/Yuhao/HiGPT-tune-lightning\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/hpc2hdd/home/xzou428/miniconda3/envs/yuh/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:11<00:00,  2.99s/it]\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "  0%|          | 0/125 [00:00<?, ?it/s]We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n",
      "100%|██████████| 125/125 [33:19<00:00, 16.00s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'perplexities': [2.5904932022094727, 4.021300315856934, 2.261462450027466, 6.2243428230285645, 4.666659355163574, 4.675596237182617, 3.9228248596191406, 5.687806606292725, 2.777909278869629, 3.544769763946533, 11.398309707641602, 4.284066200256348, 4.431798934936523, 2.9192440509796143, 2.6259055137634277, 3.648648500442505, 4.147417068481445, 5.530505180358887, 7.727009296417236, 2.278607130050659, 1.0502769947052002, 5.2223920822143555, 3.7185781002044678, 4.1539082527160645, 3.721130847930908, 4.4218854904174805, 4.295907974243164, 3.5394885540008545, 4.444467544555664, 3.042696714401245, 3.10427188873291, 3.3197951316833496, 4.214424133300781, 3.902693748474121, 4.469727516174316, 4.2982177734375, 5.405079364776611, 4.857685089111328, 4.935823917388916, 3.835064649581909, 3.563739776611328, 5.0870819091796875, 3.870135545730591, 5.654897689819336, 5.11680793762207, 4.494389533996582, 3.4346182346343994, 3.1822702884674072, 3.9255173206329346, 4.649531364440918, 3.444899797439575, 3.8123300075531006, 3.846590995788574, 5.222723960876465, 4.871662139892578, 4.793495178222656, 5.195714950561523, 4.269996166229248, 3.8041508197784424, 4.204948902130127, 4.582414627075195, 2.9781432151794434, 4.384022235870361, 4.996068000793457, 3.3283963203430176, 4.995181560516357, 3.7345149517059326, 3.852226972579956, 3.288252592086792, 13.946520805358887, 4.415710926055908, 4.3240885734558105, 5.5196099281311035, 7.5791850090026855, 6.388134956359863, 2.2077720165252686, 5.76879358291626, 2.993288040161133, 4.854260444641113, 3.8345353603363037, 4.037947654724121, 3.179616689682007, 4.935528755187988, 5.04789400100708, 6.560720443725586, 3.3811023235321045, 3.6874942779541016, 7.379669666290283, 3.785728931427002, 4.547407150268555, 3.5501019954681396, 2.983952522277832, 3.0701754093170166, 3.6599485874176025, 4.576124668121338, 5.7997026443481445, 4.8055419921875, 5.455244541168213, 5.047201156616211, 5.181970119476318, 5.270913600921631, 4.858402252197266, 
2.9926390647888184, 4.976014137268066, 4.6026482582092285, 4.633013725280762, 3.414996862411499, 6.659355640411377, 3.884490966796875, 5.321488380432129, 4.878677845001221, 6.348527431488037, 6.3732829093933105, 4.192233562469482, 6.5283002853393555, 6.7295756340026855, 1.0479986667633057, 3.5640504360198975, 6.344956874847412, 5.0100579261779785, 7.200707912445068, 2.916060209274292, 4.6285786628723145, 6.735720634460449, 3.35170578956604, 4.589212894439697, 5.393253803253174, 5.289468288421631, 3.7621090412139893, 9.246577262878418, 1.0753803253173828, 3.832934617996216, 4.822214603424072, 4.446679592132568, 3.916062831878662, 5.434490203857422, 5.257340431213379, 11.888200759887695, 4.76943826675415, 2.7659754753112793, 3.620274782180786, 2.936845064163208, 5.115570068359375, 1.0487838983535767, 3.740358829498291, 5.2107462882995605, 3.7829606533050537, 7.912901401519775, 7.10022497177124, 3.2284817695617676, 3.870907783508301, 4.517818450927734, 7.101657867431641, 4.000471591949463, 3.3696911334991455, 5.522930145263672, 4.064764976501465, 5.43998384475708, 7.000139236450195, 4.3464274406433105, 3.3972809314727783, 2.8539369106292725, 4.41402006149292, 3.7746782302856445, 4.189809322357178, 8.064919471740723, 2.7197105884552, 9.81044864654541, 3.659280300140381, 4.416683673858643, 4.311310768127441, 2.199829578399658, 2.8789360523223877, 5.026443958282471, 4.084171772003174, 4.372322082519531, 7.142247200012207, 2.914579391479492, 4.0626630783081055, 6.038825511932373, 4.812328338623047, 4.945138931274414, 7.086803913116455, 6.116292953491211, 4.474534034729004, 6.236203193664551, 6.042630195617676, 1.0457254648208618, 5.626357555389404, 5.3768744468688965, 4.301330089569092, 4.9731855392456055, 4.925156593322754, 5.859801769256592, 5.790477752685547, 6.85867977142334, 3.5453476905822754, 3.6990256309509277, 5.703799247741699, 2.968090295791626, 5.478133201599121, 7.924966335296631, 2.2905795574188232, 3.6248741149902344, 3.7600038051605225, 6.057468414306641, 
6.365182399749756, 5.887929916381836, 3.847196340560913, 3.1657135486602783, 5.980282306671143, 4.461937427520752, 3.818176507949829, 5.9718546867370605, 5.993371963500977, 4.598455905914307, 1.05963933467865, 7.685187339782715, 4.059790134429932, 6.711164474487305, 1.111208438873291, 3.7667975425720215, 4.92750358581543, 3.240260124206543, 3.587874412536621, 2.91495680809021, 3.9441659450531006, 5.066460609436035, 5.234063625335693, 4.749260425567627, 4.501170635223389, 3.876809597015381, 2.837257146835327, 4.728872299194336, 5.031030178070068, 5.336129665374756, 3.72556734085083, 3.3835034370422363, 5.328119277954102, 3.3911280632019043, 3.9747259616851807, 4.9451375007629395, 4.899710178375244, 4.323750019073486, 6.875537395477295, 5.0713210105896, 3.7343833446502686, 2.6276814937591553, 6.849911212921143, 6.728526592254639, 5.074851036071777, 3.1808922290802, 3.4872894287109375, 4.249349117279053, 4.456302642822266, 5.40125036239624, 2.2727303504943848, 4.307484149932861, 3.7334656715393066, 3.144000768661499, 6.077638626098633, 3.4268453121185303, 4.402128219604492, 7.178847789764404, 3.7609658241271973, 4.005910396575928, 5.934648036956787, 4.631855010986328, 3.212538957595825, 3.3535940647125244, 5.57070779800415, 3.3531057834625244, 3.2583727836608887, 6.004110813140869, 4.8375115394592285, 3.808367967605591, 3.483140230178833, 2.9851300716400146, 4.330249786376953, 4.266688346862793, 4.105166912078857, 6.842380046844482, 4.208712577819824, 4.209193229675293, 3.9118595123291016, 3.1286373138427734, 4.076567649841309, 9.145045280456543, 3.7493338584899902, 4.906206130981445, 4.19926118850708, 5.509835243225098, 3.118542194366455, 5.179491996765137, 6.0732741355896, 3.7976174354553223, 3.6937711238861084, 4.232518196105957, 5.912298202514648, 1.0693047046661377, 5.032642364501953, 2.769211769104004, 1.0652533769607544, 4.694488048553467, 4.325316905975342, 3.778135299682617, 2.4486591815948486, 4.76910400390625, 2.93142032623291, 6.640355587005615, 
4.646580219268799, 5.567437648773193, 1.103580355644226, 3.38783860206604, 1.3569515943527222, 5.060390472412109, 1.081871509552002, 3.9410414695739746, 3.7728774547576904, 3.722510814666748, 4.472904682159424, 4.012254238128662, 4.491535663604736, 2.4362876415252686, 6.117192268371582, 2.6465327739715576, 4.798973560333252, 2.7829177379608154, 4.110867023468018, 3.8785808086395264, 5.040511608123779, 4.164000034332275, 6.248891830444336, 4.259618282318115, 6.351233959197998, 3.6628689765930176, 3.999636173248291, 5.77549409866333, 6.959321975708008, 3.3205811977386475, 4.246020317077637, 5.8255696296691895, 5.537302494049072, 9.03074836730957, 4.322891712188721, 8.655973434448242, 3.5905303955078125, 4.781340599060059, 3.728151321411133, 3.599400758743286, 5.282216548919678, 4.7462334632873535, 5.630824565887451, 3.0534801483154297, 3.077700138092041, 3.8804898262023926, 3.6545238494873047, 7.374616622924805, 6.324067115783691, 3.906757354736328, 3.6091651916503906, 4.807150363922119, 3.345533609390259, 3.884528398513794, 2.662431001663208, 4.945836544036865, 4.538565158843994, 2.976179599761963, 5.867771148681641, 3.4262874126434326, 3.8844716548919678, 2.3957512378692627, 3.8877065181732178, 4.2323832511901855, 6.109151363372803, 4.5037384033203125, 5.145776271820068, 2.518064022064209, 1.5289674997329712, 2.38273286819458, 4.568329811096191, 3.9137916564941406, 3.9262444972991943, 3.605287790298462, 5.24673318862915, 5.678495407104492, 4.805059909820557, 5.912844657897949, 5.049302101135254, 6.765022277832031, 2.6358957290649414, 3.915639877319336, 7.814665794372559, 2.335576295852661, 4.244839668273926, 5.537450790405273, 6.243108749389648, 5.0204339027404785, 5.86354923248291, 3.960081100463867, 3.568795919418335, 2.912886381149292, 3.6749613285064697, 5.443187713623047, 3.767533302307129, 3.9618585109710693, 4.866428852081299, 5.384791851043701, 5.311997890472412, 2.876251697540283, 3.5621423721313477, 3.0090363025665283, 2.489631414413452, 
8.011942863464355, 6.406247138977051, 6.063925266265869, 6.759958267211914, 6.596734523773193, 4.782682418823242, 4.720730781555176, 2.8384761810302734, 4.6697516441345215, 5.6317057609558105, 3.227267265319824, 6.317157745361328, 6.452974319458008, 2.9853463172912598, 4.577663898468018, 4.258042812347412, 5.264667510986328, 3.2194695472717285, 1.8372230529785156, 4.190039157867432, 5.094682216644287, 4.120458602905273, 5.888673305511475, 3.5060558319091797, 5.701857089996338, 5.276764392852783, 6.559566497802734, 4.135857582092285, 4.018261909484863, 4.036769390106201, 3.8112878799438477, 7.27813196182251, 2.8905603885650635, 5.237534999847412, 2.6313087940216064, 2.527042865753174, 4.938496112823486, 4.1103315353393555, 6.0847883224487305, 6.352822780609131, 3.214421033859253, 6.903077125549316, 4.783827304840088, 8.17324161529541, 5.297483921051025, 5.090672016143799, 3.9299521446228027, 3.6423192024230957, 5.83586311340332, 6.255082130432129, 4.811574459075928, 5.587100982666016, 4.019298553466797, 5.021966934204102, 1.8409695625305176, 4.177535533905029, 2.8846595287323, 4.032863616943359, 4.02245569229126, 6.034797191619873, 3.148261308670044, 4.9148850440979, 2.8543198108673096, 3.286680221557617, 4.34840202331543, 2.604199171066284, 3.9040310382843018, 3.2984306812286377, 4.179610729217529, 3.9831199645996094, 4.884068489074707, 4.254683971405029, 4.03407621383667, 3.9408342838287354, 1.9360982179641724, 6.208914279937744, 3.2434027194976807, 5.3845086097717285, 8.249587059020996, 3.734818458557129, 3.346466541290283, 7.016280174255371, 3.3455071449279785, 4.848718166351318, 4.218443393707275, 6.1786909103393555, 3.246546745300293], 'mean_perplexity': 4.52858386182785}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "%cd /hpc2hdd/home/xzou428/Yuhao/HiGPT-tune-lightning/\n",
    "\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Pin this evaluation to physical GPU 1. CUDA_VISIBLE_DEVICES only takes effect if it is set\n",
    "# before CUDA is initialised in the process, so it must come before importing evaluate/torch,\n",
    "# and it is silently ignored if this kernel already used the GPU (e.g. after the cell above).\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "if 'torch' in sys.modules and sys.modules['torch'].cuda.is_initialized():\n",
    "    print('WARNING: CUDA is already initialised in this kernel; CUDA_VISIBLE_DEVICES=1 is ignored. Restart the kernel to pin GPU 1.')\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "from evaluate import load\n",
    "\n",
    "# Reference LM used to score the generated text, and the directory of per-sample predictions.\n",
    "MODEL_ID = '/hpc2hdd/home/xzou428/Yuhao/llama3-8b-instruct'\n",
    "PREDICT_DIR = Path('checkpoints/higpt-stage2-llama3-new_batch_rw-epoch30-8192-full_finetune/lightning_logs/version_9/predict')\n",
    "\n",
    "# Path.glob yields files in arbitrary filesystem order; sort so the per-file\n",
    "# `perplexities` list is reproducible and can be matched back to file names.\n",
    "predict_files = sorted(PREDICT_DIR.glob('*.txt'))\n",
    "# An empty/wrong directory would otherwise surface as an obscure error inside the metric.\n",
    "assert predict_files, f'no .txt predictions found in {PREDICT_DIR.resolve()}'\n",
    "\n",
    "names, predictions, empty = [], [], []\n",
    "for path in predict_files:\n",
    "    text = path.read_text(encoding='utf-8').strip()\n",
    "    # The perplexity metric requires every input to be at least one token long, so a\n",
    "    # single empty file would abort the whole (~30 min) run; record and skip it instead.\n",
    "    if text:\n",
    "        names.append(path.name)\n",
    "        predictions.append(text)\n",
    "    else:\n",
    "        empty.append(path.name)\n",
    "if empty:\n",
    "    print(f'skipping {len(empty)} empty prediction file(s): {empty[:5]}')\n",
    "assert predictions, f'all prediction files in {PREDICT_DIR} are empty'\n",
    "\n",
    "perplexity = load('perplexity', module_type='metric')\n",
    "results = perplexity.compute(predictions=predictions, model_id=MODEL_ID, device='cuda', batch_size=4)\n",
    "# Keep per-sample scores addressable by file name instead of dumping the raw list.\n",
    "per_file_perplexity = dict(zip(names, results['perplexities']))\n",
    "print(f\"{len(predictions)} predictions, mean perplexity = {results['mean_perplexity']:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check on the files picked up by the cell above: a count of 0 means the predict\n",
    "# directory is wrong or the working directory is not the repo root.\n",
    "len(predict_files), predict_files[:5]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.14 ('yuh')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "92619f85c62eb73280c07ca2268c8e47b90999a589aa097a1a08a504bd2fb2c6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
